"""
"""

import json
from copy import deepcopy
from random import sample, shuffle
from typing import Dict, List, Optional, Sequence

import numpy as np
import torch
from cfg import BaseCfg, ModelCfg, TrainCfg  # noqa: F401
from data_reader import CINC2022Reader, PCGDataBase
from inputs import InputConfig, MelSpectrogramInput, MFCCInput, SpectralInput, SpectrogramInput, WaveformInput  # noqa: F401
from torch.utils.data.dataset import Dataset
from tqdm.auto import tqdm

from torch_ecg._preprocessors import PreprocManager
from torch_ecg.cfg import CFG
from torch_ecg.utils.misc import ReprMixin, list_sum
from torch_ecg.utils.utils_data import ensure_siglen, stratified_train_test_split

# Public API of this module: only the dataset class is exported by
# ``from ... import *``; the per-record readers below are internal helpers.
__all__ = [
    "CinC2022Dataset",
]


class CinC2022Dataset(Dataset, ReprMixin):
    """Dataset of the CinC2022 (PCG) challenge data.

    Records are selected from a subject-level train/test split
    (see :meth:`_train_test_split`) and read through a per-task data reader
    (``FastDataReader`` or ``MutiTaskFastDataReader``). When ``lazy`` is True,
    nothing is loaded at construction; the first call to ``__len__`` or
    ``__getitem__`` loads every record at once and keeps the result in
    :attr:`cache`.
    """

    __name__ = "CinC2022Dataset"

    def __init__(self, config: CFG, task: str, training: bool = True, lazy: bool = True) -> None:
        """
        Parameters
        ----------
        config : CFG
            Configuration; it is deep-copied. Keys read here include
            ``db_dir``, ``torch_dtype``, ``classification``, ``segmentation``
            and the optional ``ignore_unannotated`` / ``entry_test_flag``.
        task : str
            Task name (case-insensitive); must be one of ``TrainCfg.tasks``.
        training : bool, default True
            Selects the training (True) or test (False) subject list; the
            record list is shuffled only when True.
        lazy : bool, default True
            If True, defer loading the data until it is first needed.
        """
        super().__init__()
        self.config = CFG(deepcopy(config))
        # ``self.task`` is set in ``self.__set_task``
        self.training = training
        self.lazy = lazy

        self.reader = CINC2022Reader(
            self.config.db_dir,
            ignore_unannotated=self.config.get("ignore_unannotated", True),
        )

        self.subjects = self._train_test_split()
        df = self.reader.df_stats[self.reader.df_stats["Patient ID"].isin(self.subjects)]
        self.records = list_sum([self.reader.subject_records[row["Patient ID"]] for _, row in df.iterrows()])
        if self.config.get("entry_test_flag", False):
            # quick smoke-test mode: keep a random 20% of the records
            self.records = sample(self.records, int(len(self.records) * 0.2))
        if self.training:
            shuffle(self.records)

        self.dtype = np.float64 if self.config.torch_dtype == torch.float64 else np.float32

        # Separate preprocessing pipelines for the classification and the
        # segmentation settings. ``random=False`` is set first, so a ``random``
        # entry in the task config (if any) overrides it.
        ppm_config = CFG(random=False)
        ppm_config.update(deepcopy(self.config.classification))
        seg_ppm_config = CFG(random=False)
        seg_ppm_config.update(deepcopy(self.config.segmentation))
        self.ppm = PreprocManager.from_config(ppm_config)
        self.seg_ppm = PreprocManager.from_config(seg_ppm_config)

        self.__cache = None
        self.__set_task(task, lazy)

    def __len__(self) -> int:
        """Number of cached samples (first-axis length of ``"waveforms"``).

        Triggers loading of all data if nothing is cached yet.
        """
        if self.cache is None:
            self._load_all_data()
        return self.cache["waveforms"].shape[0]

    def __getitem__(self, index: int) -> Dict[str, np.ndarray]:
        """Return the ``index``-th entry of every cached array, keyed as in the cache.

        Triggers loading of all data if nothing is cached yet.
        """
        if self.cache is None:
            self._load_all_data()
        return {k: v[index] for k, v in self.cache.items()}

    def __set_task(self, task: str, lazy: bool) -> None:
        """Set the current task, build its data reader and optionally load all data.

        Parameters
        ----------
        task : str
            Task name (case-insensitive); must be one of ``TrainCfg.tasks``.
        lazy : bool
            If True, only build the data reader; otherwise read every record
            and fill the cache.

        Raises
        ------
        ValueError
            If the task has no associated data reader, or if there are no
            records to load when ``lazy`` is False.
        """
        assert task.lower() in TrainCfg.tasks, f"illegal task \042{task}\042"
        if (
            hasattr(self, "task")
            and self.task == task.lower()
            and self.cache is not None
            and len(self.cache["waveforms"]) > 0
        ):
            # data of the same task is already cached: nothing to do
            return
        self.task = task.lower()
        # always index the config with the normalized (lower-case) task name
        task_cfg = self.config[self.task]
        self.siglen = int(task_cfg.fs * task_cfg.siglen)
        self.classes = task_cfg.classes
        self.n_classes = len(task_cfg.classes)
        self.lazy = lazy

        if self.task in ["classification"]:
            self.fdr = FastDataReader(self.reader, self.records, self.config, self.task, self.ppm)
        elif self.task in ["segmentation"]:
            self.fdr = FastDataReader(self.reader, self.records, self.config, self.task, self.seg_ppm)
        elif self.task in ["multi_task"]:
            self.fdr = MutiTaskFastDataReader(self.reader, self.records, self.config, self.task, self.ppm)
        else:
            raise ValueError("Illegal task")

        if self.lazy:
            return

        tmp_cache = []
        with tqdm(
            range(len(self.fdr)),
            desc="Loading data",
            unit="records",
            dynamic_ncols=True,
            mininterval=1.0,
        ) as pbar:
            for idx in pbar:
                tmp_cache.append(self.fdr[idx])
        if not tmp_cache:
            raise ValueError(f"No records to load for task {self.task!r}")
        keys = tmp_cache[0].keys()
        # per-record outputs are concatenated along axis 0
        # (the leading segment axis for "waveforms")
        self.__cache = {k: np.concatenate([v[k] for v in tmp_cache]) for k in keys}

    def _load_all_data(self) -> None:
        """Load every record of the current task into the cache."""
        self.__set_task(self.task, lazy=False)

    def _train_test_split(self, train_ratio: float = 0.8, force_recompute: bool = False) -> List[str]:
        """Return the patient IDs of the training (or, if not ``self.training``, test) subjects.

        The split is first looked up as ``train_ratio_{N}.json`` /
        ``test_ratio_{100-N}.json`` (N = train percentage) in the database
        directory, then in ``BaseCfg.project_dir / "utils"``. Otherwise, or if
        ``force_recompute`` is True, a stratified split on murmur, age, sex,
        pregnancy status and outcome is computed and written to both
        locations. A freshly computed split is shuffled before it is returned.

        Parameters
        ----------
        train_ratio : float, default 0.8
            Fraction of subjects in the training set; must be in (0, 1).
        force_recompute : bool, default False
            If True, ignore existing split files and recompute the split.
        """
        # ``round`` instead of ``int``: e.g. 0.57 * 100 == 56.99999999999999
        _train_ratio = round(train_ratio * 100)
        _test_ratio = 100 - _train_ratio
        assert _train_ratio * _test_ratio > 0

        train_file = self.reader.db_dir / f"train_ratio_{_train_ratio}.json"
        test_file = self.reader.db_dir / f"test_ratio_{_test_ratio}.json"
        aux_train_file = BaseCfg.project_dir / "utils" / f"train_ratio_{_train_ratio}.json"
        aux_test_file = BaseCfg.project_dir / "utils" / f"test_ratio_{_test_ratio}.json"

        if not force_recompute and train_file.exists() and test_file.exists():
            if self.training:
                return json.loads(train_file.read_text())
            else:
                return json.loads(test_file.read_text())

        if not force_recompute and aux_train_file.exists() and aux_test_file.exists():
            if self.training:
                return json.loads(aux_train_file.read_text())
            else:
                return json.loads(aux_test_file.read_text())

        df_train, df_test = stratified_train_test_split(
            self.reader.df_stats,
            [
                "Murmur",
                "Age",
                "Sex",
                "Pregnancy status",
                "Outcome",
            ],
            test_ratio=1 - train_ratio,
        )

        train_set = df_train["Patient ID"].tolist()
        test_set = df_test["Patient ID"].tolist()

        train_file.write_text(json.dumps(train_set, ensure_ascii=False))
        aux_train_file.write_text(json.dumps(train_set, ensure_ascii=False))
        test_file.write_text(json.dumps(test_set, ensure_ascii=False))
        aux_test_file.write_text(json.dumps(test_set, ensure_ascii=False))

        shuffle(train_set)
        shuffle(test_set)

        if self.training:
            return train_set
        else:
            return test_set

    @property
    def cache(self) -> Optional[Dict[str, np.ndarray]]:
        """Mapping from output key to the concatenated array, or None if not loaded yet."""
        return self.__cache

    def extra_repr_keys(self) -> List[str]:
        """Attributes shown in the object's repr."""
        return ["task", "training"]


class FastDataReader(ReprMixin, Dataset):
    """Per-record data reader for the "classification" and "segmentation" tasks.

    Item ``index`` is a dict of numpy arrays built from ``records[index]``:
    the (optionally preprocessed) waveforms passed through ``ensure_siglen``,
    plus the labels of the task.
    """

    def __init__(
        self,
        reader: PCGDataBase,
        records: Sequence[str],
        config: CFG,
        task: str,
        ppm: Optional[PreprocManager] = None,
    ) -> None:
        """
        Parameters
        ----------
        reader : PCGDataBase
            Database reader used to load signals and annotations.
        records : Sequence[str]
            Names of the records served by this reader, in order.
        config : CFG
            Configuration; ``config[task]`` holds the task-specific settings
            and ``config.torch_dtype`` selects the float dtype of the outputs.
        task : str
            "classification" or "segmentation".
        ppm : PreprocManager, optional
            Preprocessing pipeline applied to the raw waveforms (if truthy).
        """
        self.reader = reader
        self.records = records
        self.config = config
        self.task = task
        self.ppm = ppm
        if self.config.torch_dtype == torch.float64:
            self.dtype = np.float64
        else:
            self.dtype = np.float32

    def __len__(self) -> int:
        """Number of records."""
        return len(self.records)

    def _loss_name(self, target: str):
        """Return the loss name configured for ``target`` (e.g. "murmur", "outcome").

        ``config[task].loss`` may be either a single loss name shared by all
        targets, or a mapping from target name to loss name. A mapping without
        an entry for ``target`` yields None, i.e. a non-CrossEntropy encoding.
        """
        loss = self.config[self.task].loss
        if isinstance(loss, dict):
            return loss.get(target)
        return loss

    def _encode_labels(
        self,
        value,
        classes_key: str,
        map_key: str,
        loss_name,
        n_segments: int,
    ) -> np.ndarray:
        """Encode one record-level label and repeat it for every segment.

        If ``loss_name`` is not "CrossEntropyLoss", returns the multi-hot
        vector ``isin(config[task].<classes_key>, value)`` stacked
        ``n_segments`` times along a new leading axis. Otherwise, returns an
        int array of length ``n_segments`` filled with
        ``config[task].<map_key>[value]``.
        """
        task_cfg = self.config[self.task]
        if loss_name != "CrossEntropyLoss":
            return (
                np.isin(getattr(task_cfg, classes_key), value)
                .astype(self.dtype)[np.newaxis, ...]
                .repeat(n_segments, axis=0)
            )
        return np.array(
            [getattr(task_cfg, map_key)[value] for _ in range(n_segments)],
            dtype=int,
        )

    def __getitem__(self, index: int) -> Dict[str, np.ndarray]:
        """Load, preprocess and label the record at position ``index``.

        Returns
        -------
        dict
            For classification: ``"waveforms"``, ``"murmur"`` and, when
            ``config[task].outcomes`` is not None, ``"outcome"``.
            For segmentation: ``"waveforms"`` and ``"segmentation"``.

        Raises
        ------
        ValueError
            If the task is neither "classification" nor "segmentation".
        """
        rec = self.records[index]
        task_cfg = self.config[self.task]
        waveforms = self.reader.load_data(
            rec,
            data_format=task_cfg.data_format,
        )
        if self.ppm:
            waveforms, _ = self.ppm(waveforms, self.reader.fs)
        waveforms = ensure_siglen(
            waveforms,
            siglen=task_cfg.input_len,
            fmt=task_cfg.data_format,
            tolerance=task_cfg.sig_slice_tol,
        ).astype(self.dtype)
        if waveforms.ndim == 2:
            # a 2-D result is treated as a single segment: add a leading segment axis
            waveforms = waveforms[np.newaxis, ...]

        n_segments = waveforms.shape[0]

        if self.task in ["classification"]:
            label = self.reader.load_ann(rec)
            # ``loss`` may be a per-target mapping (it is indexed by "outcome"
            # below), so the murmur loss must be looked up, not compared as a whole
            out = {
                "waveforms": waveforms,
                "murmur": self._encode_labels(
                    label, "classes", "class_map", self._loss_name("murmur"), n_segments
                ),
            }
            if task_cfg.outcomes is not None:
                outcome = self.reader.load_outcome(rec)
                out["outcome"] = self._encode_labels(
                    outcome, "outcomes", "outcome_map", self._loss_name("outcome"), n_segments
                )
            return out

        elif self.task in ["segmentation"]:
            label = self.reader.load_segmentation(
                rec,
                seg_format="binary",
                ensure_same_len=True,
                fs=task_cfg.fs,
            )
            label = ensure_siglen(
                label,
                siglen=task_cfg.input_len,
                fmt="channel_last",
                tolerance=task_cfg.sig_slice_tol,
            ).astype(self.dtype)
            return {"waveforms": waveforms, "segmentation": label}
        else:
            raise ValueError(f"Illegal task: {self.task}")

    def extra_repr_keys(self) -> List[str]:
        """Attributes shown in the object's repr."""
        return [
            "reader",
            "ppm",
        ]


class MutiTaskFastDataReader(ReprMixin, Dataset):
    """Per-record data reader for the "multi_task" task.

    Each item bundles the waveforms of one record with its murmur labels and,
    when configured, its outcome labels (``config[task].outcomes``) and its
    segmentation mask (``config[task].states``).
    """

    def __init__(
        self,
        reader: PCGDataBase,
        records: Sequence[str],
        config: CFG,
        task: str = "multi_task",
        ppm: Optional[PreprocManager] = None,
    ) -> None:
        """
        Parameters
        ----------
        reader : PCGDataBase
            Database reader used to load signals and annotations.
        records : Sequence[str]
            Names of the records served by this reader, in order.
        config : CFG
            Configuration; ``config[task]`` holds the task settings and
            ``config.torch_dtype`` selects the float dtype of the outputs.
        task : str, default "multi_task"
            Must be "multi_task".
        ppm : PreprocManager, optional
            Preprocessing pipeline applied to the raw waveforms (if truthy).
        """
        self.reader = reader
        self.records = records
        self.config = config
        self.task = task
        assert self.task == "multi_task"
        self.ppm = ppm
        self.dtype = np.float64 if self.config.torch_dtype == torch.float64 else np.float32

    def __len__(self) -> int:
        """Number of records."""
        return len(self.records)

    def __getitem__(self, index: int) -> Dict[str, np.ndarray]:
        """Load, preprocess and label the record at position ``index``.

        Returns a dict with ``"waveforms"`` and ``"murmur"``, plus
        ``"outcome"`` when ``config[task].outcomes`` is not None and
        ``"segmentation"`` when ``config[task].states`` is not None.
        """
        record = self.records[index]
        cfg = self.config[self.task]

        signal = self.reader.load_data(record, data_format=cfg.data_format)
        if self.ppm:
            signal, _ = self.ppm(signal, self.reader.fs)
        signal = ensure_siglen(
            signal,
            siglen=cfg.input_len,
            fmt=cfg.data_format,
            tolerance=cfg.sig_slice_tol,
        ).astype(self.dtype)
        if signal.ndim == 2:
            # a 2-D result is treated as a single segment: add a leading segment axis
            signal = signal[np.newaxis, ...]
        n_seg = signal.shape[0]

        def _repeat_label(value, loss_name, classes_key, map_key):
            # CrossEntropyLoss -> one class index per segment;
            # any other loss -> the multi-hot vector stacked once per segment
            if loss_name == "CrossEntropyLoss":
                return np.array(
                    [getattr(cfg, map_key)[value] for _ in range(n_seg)],
                    dtype=int,
                )
            multi_hot = np.isin(getattr(cfg, classes_key), value).astype(self.dtype)
            return multi_hot[np.newaxis, ...].repeat(n_seg, axis=0)

        murmur = self.reader.load_ann(record)
        result = {
            "waveforms": signal,
            "murmur": _repeat_label(murmur, cfg.loss["murmur"], "classes", "class_map"),
        }

        if cfg.outcomes is not None:
            outcome = self.reader.load_outcome(record)
            result["outcome"] = _repeat_label(outcome, cfg.loss["outcome"], "outcomes", "outcome_map")

        if cfg.states is not None:
            seg_mask = self.reader.load_segmentation(
                record,
                seg_format="binary",
                ensure_same_len=True,
                fs=cfg.fs,
            )
            result["segmentation"] = ensure_siglen(
                seg_mask,
                siglen=cfg.input_len,
                fmt="channel_last",
                tolerance=cfg.sig_slice_tol,
            ).astype(self.dtype)

        return result

    def extra_repr_keys(self) -> List[str]:
        """Attributes shown in the object's repr."""
        return ["reader", "ppm"]
